import time

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import pydegensac

from sys import argv
from PIL import Image
from skimage.measure import ransac
from skimage.transform import AffineTransform

from lib.model_test import D2Net
from lib.utils import preprocess_image
from lib.pyramid import process_multiscale


def extract(image, model, model2, device, multiscale=False, preprocessing='caffe'):
	# No resizing is applied here, so both rescale factors are always 1;
	# they are kept so a resize step can be reintroduced without touching
	# the keypoint rescaling below.
	resized_image = image

	fact_i = image.shape[0] / resized_image.shape[0]
	fact_j = image.shape[1] / resized_image.shape[1]

	input_image = preprocess_image(
		resized_image,
		preprocessing=preprocessing
	)
	with torch.no_grad():
		input_tensor = torch.tensor(
			input_image[np.newaxis, :, :, :].astype(np.float32),
			device=device
		)

		# Run both detectors so their outputs can be ensembled below.
		if multiscale:
			keypoints, scores, descriptors = process_multiscale(input_tensor, model)
			keypoints2, scores2, descriptors2 = process_multiscale(input_tensor, model2)
		else:
			keypoints, scores, descriptors = process_multiscale(input_tensor, model, scales=[1])
			keypoints2, scores2, descriptors2 = process_multiscale(input_tensor, model2, scales=[1])

	# Rescale to the original resolution and convert (row, col, scale)
	# to (x, y, scale).
	keypoints[:, 0] *= fact_i
	keypoints[:, 1] *= fact_j
	keypoints = keypoints[:, [1, 0, 2]]

	keypoints2[:, 0] *= fact_i
	keypoints2[:, 1] *= fact_j
	keypoints2 = keypoints2[:, [1, 0, 2]]

	# Ensemble: pool detections from both models.
	keypoints_b = np.concatenate([keypoints, keypoints2], 0)
	scores_b = np.concatenate([scores, scores2], 0)
	descriptors_b = np.concatenate([descriptors, descriptors2], 0)

	return {
		'keypoints': keypoints_b,
		'scores': scores_b,
		'descriptors': descriptors_b,
	}
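
# Example usage of extract (a sketch; assumes `model` and `model2` are loaded
# D2Net instances and `img` is an HxWx3 uint8 RGB array):
#   feat = extract(img, model, model2, torch.device('cuda'))
#   feat['keypoints']    # (N, 3) array of x, y, scale
#   feat['descriptors']  # (N, 512) D2-Net descriptors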


def extractSingle(image, model, device):
	# `image` is a preprocessed CxHxW float tensor (see read_and_process_image).
	with torch.no_grad():
		keypoints, scores, descriptors = process_multiscale(
			image.to(device).unsqueeze(0),
			model,
			scales=[1]
		)

	keypoints = keypoints[:, [1, 0, 2]]  # (row, col, scale) -> (x, y, scale)

	return {
		'keypoints': keypoints,
		'scores': scores,
		'descriptors': descriptors,
	}

def cv2D2netMatching(image1, image2, feat1, feat2, matcher="BF"):
	if matcher == "BF":

		t0 = time.time()
		bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
		# BFMatcher requires float32 descriptors.
		matches = bf.match(
			feat1['descriptors'].astype(np.float32),
			feat2['descriptors'].astype(np.float32)
		)
		matches = sorted(matches, key=lambda x: x.distance)
		t1 = time.time()
		print("Time to extract matches: ", t1 - t0)

		print("Number of raw matches:", len(matches))

		match1 = [m.queryIdx for m in matches]
		match2 = [m.trainIdx for m in matches]

		keypoints_left = feat1['keypoints'][match1, : 2]
		keypoints_right = feat2['keypoints'][match2, : 2]

		np.random.seed(0)

		t0 = time.time()

		### Ransac ###
		# model, inliers = ransac(
		# 	(keypoints_left, keypoints_right),
		# 	AffineTransform, min_samples=4,
		# 	residual_threshold=8, max_trials=10000
		# )
		####

		H, inliers = pydegensac.findHomography(keypoints_left, keypoints_right, 8.0, 0.99, 10000)

		t1 = time.time()
		print("Time for ransac: ", t1-t0)

		n_inliers = np.sum(inliers)
		print('Number of inliers: %d.' % n_inliers)

		inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_left[inliers]]
		inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in keypoints_right[inliers]]
		placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]

		image3 = cv2.drawMatches(image1, inlier_keypoints_left, image2, inlier_keypoints_right, placeholder_matches, None)

		#### Visualization ####
		# plt.figure(figsize=(20, 20))
		# plt.imshow(image3)
		# plt.axis('off')
		# plt.show()

		src_pts = np.float32([ inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
		dst_pts = np.float32([ inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches ]).reshape(-1, 2)

		return src_pts, dst_pts

	raise ValueError("Unsupported matcher: %s" % matcher)
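
# Example usage of cv2D2netMatching (a sketch; `img1`/`img2` are RGB arrays
# and `feat1`/`feat2` come from extract or extractSingle):
#   src, dst = cv2D2netMatching(img1, img2, feat1, feat2, matcher="BF")
#   src.shape  # (n_inliers, 2)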


def siftMatching(img1, img2, HFile1, HFile2, device):

	H1 = np.load(HFile1)
	H2 = np.load(HFile2)

	img1 = Image.open(img1)
	if img1.mode != 'RGB':
		img1 = img1.convert('RGB')
	img1 = np.array(img1)
	img1 = cv2.warpPerspective(img1, H1, dsize=(400, 400))

	#### Visualization ####
	# cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
	# cv2.waitKey(0)

	img2 = Image.open(img2)
	if img2.mode != 'RGB':
		img2 = img2.convert('RGB')
	img2 = np.array(img2)
	img2 = cv2.warpPerspective(img2, H2, dsize=(400, 400))

	#### Visualization ####
	# cv2.imshow("Image", cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
	# cv2.waitKey(0)

	# surf = cv2.xfeatures2d.SURF_create(100) # SURF alternative
	# SIFT moved out of xfeatures2d in OpenCV >= 4.4; support both layouts.
	try:
		sift = cv2.SIFT_create()
	except AttributeError:
		sift = cv2.xfeatures2d.SIFT_create()

	kp1, des1 = sift.detectAndCompute(img1, None)
	kp2, des2 = sift.detectAndCompute(img2, None)

	# Mutual nearest-neighbour matching under dot-product similarity.
	matches = mnn_matcher(
			torch.from_numpy(des1).float().to(device=device),
			torch.from_numpy(des2).float().to(device=device)
		)

	src_pts = np.float32([ kp1[m[0]].pt for m in matches ]).reshape(-1, 2)
	dst_pts = np.float32([ kp2[m[1]].pt for m in matches ]).reshape(-1, 2)

	if(src_pts.shape[0] < 5 or dst_pts.shape[0] < 5):
		return [], []

	H, inliers = pydegensac.findHomography(src_pts, dst_pts, 8.0, 0.99, 10000)

	n_inliers = np.sum(inliers)

	inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in src_pts[inliers]]
	inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in dst_pts[inliers]]
	placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]

	#### Visualization ####
	# image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
	# cv2.imshow('Matches', image3)
	# cv2.waitKey()

	src_pts = np.float32([ inlier_keypoints_left[m.queryIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
	dst_pts = np.float32([ inlier_keypoints_right[m.trainIdx].pt for m in placeholder_matches ]).reshape(-1, 2)
	orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, H2)

	return orgSrc, orgDst


def getTopImg(image, H, imgSize=400):
	# Warp to the top view and display it (blocks until a key is pressed).
	warpImg = cv2.warpPerspective(image, H, (imgSize, imgSize))
	cv2.imshow("Image", cv2.cvtColor(warpImg, cv2.COLOR_BGR2RGB))
	cv2.waitKey(0)

	return warpImg


def orgKeypoints(src_pts, dst_pts, H1, H2):
	# Map keypoints detected in the warped (top-view) images back to the
	# original perspective images: lift to homogeneous coordinates, apply
	# the inverse homographies, then dehomogenize.
	ones = np.ones((src_pts.shape[0], 1))

	src_pts = np.hstack((src_pts, ones))
	dst_pts = np.hstack((dst_pts, ones))

	orgSrc = np.linalg.inv(H1) @ src_pts.T
	orgDst = np.linalg.inv(H2) @ dst_pts.T

	orgSrc = orgSrc/orgSrc[2, :]
	orgDst = orgDst/orgDst[2, :]

	orgSrc = np.asarray(orgSrc)[0:2, :]
	orgDst = np.asarray(orgDst)[0:2, :]

	return orgSrc, orgDst
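
# Quick sanity check for orgKeypoints (a sketch): identity homographies
# should return the points unchanged (transposed to 2xN):
#   pts = np.array([[10., 20.], [30., 40.]])
#   s, d = orgKeypoints(pts, pts, np.eye(3), np.eye(3))
#   assert np.allclose(s, pts.T) and np.allclose(d, pts.T)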


def drawOrg(image1, image2, orgSrc, orgDst):
	img1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
	img2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

	# Draw keypoints on each image (im1/im2 stay valid even with no matches).
	im1, im2 = img1, img2
	for i in range(orgSrc.shape[1]):
		im1 = cv2.circle(im1, (int(orgSrc[0, i]), int(orgSrc[1, i])), 3, (0, 0, 255), 1)
	for i in range(orgDst.shape[1]):
		im2 = cv2.circle(im2, (int(orgDst[0, i]), int(orgDst[1, i])), 3, (0, 0, 255), 1)

	im4 = cv2.hconcat([im1, im2])
	for i in range(orgSrc.shape[1]):
		im4 = cv2.line(im4, (int(orgSrc[0, i]), int(orgSrc[1, i])), (int(orgDst[0, i]) + im1.shape[1], int(orgDst[1, i])), (0, 255, 0), 1)
	im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2RGB)  # swap back to BGR for cv2.imshow
	cv2.imshow("Image", im4)
	cv2.waitKey(0)



def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
	# Fall back to identity homographies so orgKeypoints stays valid when
	# no homography files are supplied.
	if HFile1 is None:
		H1 = np.eye(3)
		igp1, img1 = read_and_process_image(rgbFile1, H=None)
	else:
		H1 = np.load(HFile1)
		igp1, img1 = read_and_process_image(rgbFile1, H=H1)

	c, h, w = igp1.shape

	if HFile2 is None:
		H2 = np.eye(3)
		igp2, img2 = read_and_process_image(rgbFile2, H=None)
	else:
		H2 = np.load(HFile2)
		igp2, img2 = read_and_process_image(rgbFile2, H=H2)

	feat1 = extractSingle(igp1, model, device)
	feat2 = extractSingle(igp2, model, device)

	matches = mnn_matcher(
			torch.from_numpy(feat1['descriptors']).to(device=device),
			torch.from_numpy(feat2['descriptors']).to(device=device),
		)
	pos_a = feat1["keypoints"][matches[:, 0], : 2]
	pos_b = feat2["keypoints"][matches[:, 1], : 2]

	H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
	pos_a = pos_a[inliers]
	pos_b = pos_b[inliers]

	inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
	inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
	placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

	image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
	image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

	#### Visualization ####
	# cv2.imshow('Matches', image3)
	# cv2.waitKey()

	orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
	drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst) # Reproject matches to perspective View

	return orgSrc, orgDst
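
# Example usage of getPerspKeypoints (a sketch; filenames are placeholders):
#   orgSrc, orgDst = getPerspKeypoints('rgb1.png', 'rgb2.png',
#                                      'H1.npy', 'H2.npy', model, device)
# Pass HFile1=None / HFile2=None to match the images without warping.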


###### Ensemble
def read_and_process_image(img_path, resize=None, H=None, preprocessing='caffe'):
	img1 = Image.open(img_path)
	if resize:
		img1 = img1.resize(resize)
	if(img1.mode != 'RGB'):
		img1 = img1.convert('RGB')
	img1 = np.array(img1)
	if H is not None:
		img1 = cv2.warpPerspective(img1, H, dsize=(400, 400))
		# cv2.imshow("Image", cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
		# cv2.waitKey(0)
	igp1 = torch.from_numpy(preprocess_image(img1, preprocessing=preprocessing).astype(np.float32))
	return igp1, img1

def mnn_matcher_scorer(descriptors_a, descriptors_b):
	# Mutual nearest-neighbour matching that also returns each retained
	# match's dot-product similarity (higher is better).
	device = descriptors_a.device
	sim = descriptors_a @ descriptors_b.t()
	val1, nn12 = torch.max(sim, dim=1)
	val2, nn21 = torch.max(sim, dim=0)
	ids1 = torch.arange(0, sim.shape[0], device=device)
	mask = (ids1 == nn21[nn12])
	matches = torch.stack([ids1[mask], nn12[mask]]).t()
	match_scores = val1[mask]
	return matches, match_scores

def mnn_matcher(descriptors_a, descriptors_b):
	# Mutual nearest-neighbour matching under dot-product similarity;
	# returns an (N, 2) array of index pairs.
	device = descriptors_a.device
	sim = descriptors_a @ descriptors_b.t()
	nn12 = torch.max(sim, dim=1)[1]
	nn21 = torch.max(sim, dim=0)[1]
	ids1 = torch.arange(0, sim.shape[0], device=device)
	mask = (ids1 == nn21[nn12])
	matches = torch.stack([ids1[mask], nn12[mask]])
	return matches.t().data.cpu().numpy()
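
# mnn_matcher in brief (a sketch): identical descriptor sets match
# index-to-index.
#   d = torch.eye(4)
#   mnn_matcher(d, d)  # -> [[0, 0], [1, 1], [2, 2], [3, 3]]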

def apply_ransac(kp1, kp2):
	model, inliers = ransac(
		(kp1, kp2),
		AffineTransform, min_samples=4,
		residual_threshold=8, max_trials=10000
	)
	return kp1[inliers], kp2[inliers], inliers

def getPerspKeypoints2(model1, model2, rgbFile1, rgbFile2, HFile1, HFile2, device):
	# Fall back to identity homographies so orgKeypoints stays valid when
	# no homography files are supplied.
	if HFile1 is None:
		H1 = np.eye(3)
		igp1, img1 = read_and_process_image(rgbFile1, H=None)
	else:
		H1 = np.load(HFile1)
		igp1, img1 = read_and_process_image(rgbFile1, H=H1)

	c, h, w = igp1.shape

	if HFile2 is None:
		H2 = np.eye(3)
		igp2, img2 = read_and_process_image(rgbFile2, H=None)
	else:
		H2 = np.load(HFile2)
		igp2, img2 = read_and_process_image(rgbFile2, H=H2)

	with torch.no_grad():
		keypoints_a1, scores_a1, descriptors_a1 = process_multiscale(
			igp1.to(device).unsqueeze(0),
			model1,
			scales=[1]
		)
		keypoints_a1 = keypoints_a1[:, [1, 0, 2]]

		keypoints_a2, scores_a2, descriptors_a2 = process_multiscale(
			igp1.to(device).unsqueeze(0),
			model2,
			scales=[1]
		)
		keypoints_a2 = keypoints_a2[:, [1, 0, 2]]

		keypoints_b1, scores_b1, descriptors_b1 = process_multiscale(
			igp2.to(device).unsqueeze(0),
			model1,
			scales=[1]
		)
		keypoints_b1 = keypoints_b1[:, [1, 0, 2]]

		keypoints_b2, scores_b2, descriptors_b2 = process_multiscale(
			igp2.to(device).unsqueeze(0),
			model2,
			scales=[1]
		)
		keypoints_b2 = keypoints_b2[:, [1, 0, 2]]

	# Match descriptors from both models independently, keeping each match's
	# similarity score so the two match sets can be fused below.
	matches1, dist_1 = mnn_matcher_scorer(
		torch.from_numpy(descriptors_a1).to(device=device),
		torch.from_numpy(descriptors_b1).to(device=device)
	)
	matches2, dist_2 = mnn_matcher_scorer(
		torch.from_numpy(descriptors_a2).to(device=device),
		torch.from_numpy(descriptors_b2).to(device=device)
	)

	full_dist = torch.cat([dist_1, dist_2])
	assert len(full_dist) == (len(dist_1) + len(dist_2)), "score count mismatch after concatenation"

	# Keep the best half of the pooled matches by similarity, then split the
	# surviving indices back into per-model match lists.
	k_final = len(full_dist) // 2
	# k_final = len(full_dist)
	# k_final = max(len(dist_1), len(dist_2))
	top_k_indices = torch.topk(full_dist, k=k_final)[1]
	first = []
	second = []

	for valid_id in top_k_indices:
		if valid_id < len(dist_1):
			first.append(valid_id)
		else:
			second.append(valid_id - len(dist_1))

	matches1 = matches1[torch.tensor(first, device=device).long()].data.cpu().numpy()
	matches2 = matches2[torch.tensor(second, device=device).long()].data.cpu().numpy()

	pos_a1 = keypoints_a1[matches1[:, 0], : 2]
	pos_b1 = keypoints_b1[matches1[:, 1], : 2]

	pos_a2 = keypoints_a2[matches2[:, 0], : 2]
	pos_b2 = keypoints_b2[matches2[:, 1], : 2]

	pos_a = np.concatenate([pos_a1, pos_a2], 0)
	pos_b = np.concatenate([pos_b1, pos_b2], 0)

	# pos_a, pos_b, inliers = apply_ransac(pos_a, pos_b)
	H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
	pos_a = pos_a[inliers]
	pos_b = pos_b[inliers]

	inlier_keypoints_left = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_a]
	inlier_keypoints_right = [cv2.KeyPoint(point[0], point[1], 1) for point in pos_b]
	placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

	image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
	image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)
	# cv2.imshow('Matches', image3)
	# cv2.waitKey()


	orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
	drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

	return orgSrc, orgDst
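
# Example usage of getPerspKeypoints2 (a sketch; filenames are placeholders):
#   orgSrc, orgDst = getPerspKeypoints2(model1, model2, 'rgb1.png', 'rgb2.png',
#                                       'H1.npy', 'H2.npy', device)
# model1 and model2 are two D2Net variants whose matches are pooled and
# pruned to the best-scoring half before RANSAC.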

##### SuperPoint

def frame2tensor(frame, device):
	# Scale to [0, 1] and add batch and channel dims: HxW -> 1x1xHxW.
	return torch.from_numpy(frame / 255.).float()[None, None].to(device)

def read_and_process_image_superpoint(img_path, device, resize=None, H=None):
	img1 = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)

	if resize:
		img1 = cv2.resize(img1, resize)
	if H is not None:
		img1 = cv2.warpPerspective(img1, H, dsize=(400, 400))

	igp1 = frame2tensor(img1, device)
	return igp1
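
# frame2tensor in brief (a sketch): a 400x400 uint8 frame becomes a
# (1, 1, 400, 400) float tensor in [0, 1]:
#   t = frame2tensor(np.zeros((400, 400), dtype=np.uint8), torch.device('cpu'))
#   t.shape  # torch.Size([1, 1, 400, 400])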

def super_point_matcher(matcher, rgbFile1, rgbFile2, HFile1, HFile2, device):

	H1 = np.load(HFile1)
	H2 = np.load(HFile2)
	# igp1 = read_and_process_image_superpoint(rgbFile1, device, H=None)#, resize=(400, 400))
	igp1 = read_and_process_image_superpoint(rgbFile1, device, H=H1)#, resize=(400, 400))
	# c,h,w = igp1.shape
	# igp2 = read_and_process_image_superpoint(rgbFile2, device, H=None)#, resize=(400, 400))
	igp2 = read_and_process_image_superpoint(rgbFile2, device, H=H2)#, resize=(400, 400))

	with torch.no_grad():
		pred = matcher({'image0': igp1, 'image1': igp2})
		pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

		keypoints_a, descriptors_a, scores_a = pred['keypoints0'], pred['descriptors0'].T, pred['scores0']
		keypoints_b, descriptors_b, scores_b = pred['keypoints1'], pred['descriptors1'].T, pred['scores1']
		matches1 = pred['matches0']

		# Keep only keypoints in image0 that were matched (-1 marks unmatched).
		matches1_mask = matches1 != -1
		indcs = np.arange(0, len(matches1))

		mat_indcs = indcs[matches1_mask]
		final_matches1 = matches1[matches1_mask]
		matches_sg = np.stack((mat_indcs, final_matches1), 1)

	pos_a = keypoints_a[matches_sg[:, 0], : 2]
	pos_b = keypoints_b[matches_sg[:, 1], : 2]

	# H, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
	# pos_a = pos_a[inliers]
	# pos_b = pos_b[inliers]

	orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, H2)
	drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)
	return orgSrc, orgDst


if __name__ == '__main__':
	WEIGHTS = '/home/udit/udit/d2-net/models/d2_kinal_ipr.pth'
	#WEIGHTS = '/home/udit/d2-net/models/d2_tf.pth'
	srcR = argv[1]
	trgR = argv[2]
	srcH = argv[3]
	trgH = argv[4]

	# getPerspKeypoints expects a loaded D2Net model and a torch device,
	# not a raw weights path and a device string.
	use_cuda = torch.cuda.is_available()
	device = torch.device('cuda' if use_cuda else 'cpu')
	model = D2Net(model_file=WEIGHTS, use_relu=True, use_cuda=use_cuda)

	orgSrc, orgDst = getPerspKeypoints(srcR, trgR, srcH, trgH, model, device)
